package com.yeskery.nut.extend.oauth;

import com.yeskery.nut.core.HttpHeader;
import com.yeskery.nut.core.RequestHandler;
import com.yeskery.nut.core.ResponseCode;
import com.yeskery.nut.core.ResponseNutException;
import com.yeskery.nut.extend.auth.AuthPlugin;
import com.yeskery.nut.extend.auth.Authenticate;
import com.yeskery.nut.util.StringUtils;

import java.util.function.Consumer;
import java.util.function.Function;

/**
 * OAuth拦截器
 * @author YESKERY
 * 2024/10/22
 */
public class OAuthInterceptor extends AuthPlugin {

    /** Prefix of a Bearer credential in the Authorization header: the scheme name followed by one space. */
    public static final String BEARER_AUTH_PREFIX = "Bearer ";

    /** Default unauthorized handler: rejects the request with {@link ResponseCode#UNAUTHORIZED}. */
    private static final RequestHandler DEFAULT_UNAUTHORIZED_HANDLER = (request, response, execution) -> {
        throw new ResponseNutException(ResponseCode.UNAUTHORIZED);
    };

    /**
     * Creates an OAuth interceptor through the superclass's no-arg constructor.
     */
    public OAuthInterceptor() {
    }

    /**
     * Creates an OAuth interceptor that validates Bearer tokens with the given generator.
     * @param accessTokenGenerator access token generator whose {@code validateAccessToken} decides validity
     * @param unauthorizedHandler handler passed to the superclass for unauthorized requests
     */
    public OAuthInterceptor(AccessTokenGenerator accessTokenGenerator, RequestHandler unauthorizedHandler) {
        this(accessTokenGenerator::validateAccessToken, unauthorizedHandler);
    }

    /**
     * Creates an OAuth interceptor that validates Bearer tokens with the given function.
     * <p>The request is authenticated only when the {@code Authorization} header carries a
     * non-blank Bearer token and {@code authenticateFunction} returns {@code Boolean.TRUE} for it.
     * A {@code null} result from the function is treated as "not authenticated".
     * @param authenticateFunction OAuth authentication function, applied to the extracted token
     * @param unauthorizedHandler handler passed to the superclass for unauthorized requests
     */
    public OAuthInterceptor(Function<String, Boolean> authenticateFunction, RequestHandler unauthorizedHandler) {
        super((request, controller) -> {
            String token = extractBearerToken(request.getHeader(HttpHeader.AUTHORIZATION));
            // Boolean.TRUE.equals avoids an unboxing NullPointerException if the function returns null.
            return token != null && Boolean.TRUE.equals(authenticateFunction.apply(token));
        }, unauthorizedHandler);
    }

    /**
     * Creates an OAuth interceptor that validates Bearer tokens with the given generator and
     * responds to unauthorized requests by throwing a {@link ResponseNutException} with
     * {@link ResponseCode#UNAUTHORIZED}.
     * @param accessTokenGenerator access token generator
     */
    public OAuthInterceptor(AccessTokenGenerator accessTokenGenerator) {
        this(accessTokenGenerator, DEFAULT_UNAUTHORIZED_HANDLER);
    }

    /**
     * Creates an OAuth interceptor that validates Bearer tokens with the given function and
     * responds to unauthorized requests by throwing a {@link ResponseNutException} with
     * {@link ResponseCode#UNAUTHORIZED}.
     * @param authenticateFunction OAuth authentication function
     */
    public OAuthInterceptor(Function<String, Boolean> authenticateFunction) {
        this(authenticateFunction, DEFAULT_UNAUTHORIZED_HANDLER);
    }

    /**
     * Creates an authentication function based on an access token generator.
     * <p>The returned function extracts the Bearer token from the {@code Authorization} header
     * and validates it with {@code accessTokenGenerator.validateAccessToken}. On success, if
     * {@code accessTokenValueConsumer} is non-null, it receives
     * {@code accessTokenGenerator.getAccessTokenValue(token)}.
     * @param accessTokenGenerator access token generator
     * @param accessTokenValueConsumer consumer of the access token value; may be {@code null}
     * @return the authentication function
     */
    protected Authenticate createAuthenticate(AccessTokenGenerator accessTokenGenerator, Consumer<String> accessTokenValueConsumer) {
        return (request, controller) -> {
            String token = extractBearerToken(request.getHeader(HttpHeader.AUTHORIZATION));
            if (token == null || !accessTokenGenerator.validateAccessToken(token)) {
                return false;
            }
            if (accessTokenValueConsumer != null) {
                accessTokenValueConsumer.accept(accessTokenGenerator.getAccessTokenValue(token));
            }
            return true;
        };
    }

    /**
     * Extracts the Bearer token from an {@code Authorization} header value.
     * <p>The scheme name is matched case-insensitively, as RFC 7235 requires for auth-schemes.
     * Surrounding whitespace around the token is removed, so extra separator spaces
     * (RFC 6750 allows one or more) are tolerated.
     * @param header raw {@code Authorization} header value; may be {@code null}
     * @return the token, or {@code null} if the header is absent, is not a Bearer credential,
     *         or carries a blank token
     */
    private static String extractBearerToken(String header) {
        if (StringUtils.isEmpty(header)
                || !header.regionMatches(true, 0, BEARER_AUTH_PREFIX, 0, BEARER_AUTH_PREFIX.length())) {
            return null;
        }
        String token = header.substring(BEARER_AUTH_PREFIX.length()).trim();
        return token.isEmpty() ? null : token;
    }
}
